!--------------------------------------------------------------------------------------------------!
! Copyright (C) by the DBCSR developers group - All rights reserved                                !
! This file is part of the DBCSR library.                                                          !
!                                                                                                  !
! For information on the license, see the LICENSE file.                                            !
! For further information please visit https://dbcsr.cp2k.org                                      !
! SPDX-License-Identifier: GPL-2.0+                                                                !
!--------------------------------------------------------------------------------------------------!

MODULE dbcsr_csr_conversions
   !! DBCSR to CSR matrix format conversion
   USE dbcsr_block_access, ONLY: dbcsr_put_block
   USE dbcsr_data_methods, ONLY: dbcsr_data_clear_pointer, &
                                 dbcsr_data_init, &
                                 dbcsr_data_new, &
                                 dbcsr_data_release
   USE dbcsr_data_types, ONLY: dbcsr_type_complex_4, &
                               dbcsr_type_complex_8, &
                               dbcsr_type_real_4, &
                               dbcsr_type_real_8, &
                               dbcsr_type_real_default
   USE dbcsr_dist_methods, ONLY: dbcsr_distribution_col_dist, &
                                 dbcsr_distribution_mp, &
                                 dbcsr_distribution_new, &
                                 dbcsr_distribution_release
   USE dbcsr_iterator_operations, ONLY: dbcsr_iterator_blocks_left, &
                                        dbcsr_iterator_next_block, &
                                        dbcsr_iterator_start, &
                                        dbcsr_iterator_stop
   USE dbcsr_kinds, ONLY: default_string_length, &
                          dp, &
                          int_8, &
                          real_4, &
                          real_8, &
                          sp
   USE dbcsr_methods, ONLY: &
      dbcsr_col_block_sizes, dbcsr_distribution, dbcsr_get_data_type, dbcsr_get_num_blocks, &
      dbcsr_get_nze, dbcsr_has_symmetry, dbcsr_name, dbcsr_nblkcols_total, dbcsr_nblkrows_total, &
      dbcsr_nfullrows_local, dbcsr_release, dbcsr_row_block_sizes, dbcsr_valid_index
   USE dbcsr_mp_methods, ONLY: dbcsr_mp_group, &
                               dbcsr_mp_mynode, &
                               dbcsr_mp_new, &
                               dbcsr_mp_numnodes, &
                               dbcsr_mp_release
   USE dbcsr_mpiwrap, ONLY: mp_environ, &
                            mp_gather, &
                            mp_recv, &
                            mp_send, &
                            mp_sum, mp_comm_type
   USE dbcsr_operations, ONLY: dbcsr_copy, &
                               dbcsr_get_info, &
                               dbcsr_set
   USE dbcsr_transformations, ONLY: dbcsr_complete_redistribute, &
                                    dbcsr_desymmetrize_deep
   USE dbcsr_types, ONLY: dbcsr_data_obj, &
                          dbcsr_distribution_obj, &
                          dbcsr_iterator, &
                          dbcsr_mp_obj, &
                          dbcsr_type
   USE dbcsr_work_operations, ONLY: dbcsr_create
#include "base/dbcsr_base_uses.f90"

   IMPLICIT NONE
   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbcsr_csr_conversions'
      !! name of this module

   LOGICAL, PARAMETER, PRIVATE          :: careful_mod = .FALSE.
      !! NOTE(review): presumably a switch for extra consistency checks; it is not
      !! referenced in the part of the module reviewed here - confirm before relying on it

   INTEGER, PARAMETER, PUBLIC           :: csr_dbcsr_blkrow_dist = 1, csr_eqrow_ceil_dist = 2, csr_eqrow_floor_dist = 3
      !! selectors for how CSR rows are distributed over processes
      !! (accepted values of the dist_format argument of csr_create_from_dbcsr)

   TYPE csr_mapping_data
      !! Mapping data relating local CSR indices to local indices of a block-row
      !! distributed (BRD) DBCSR matrix, and containing the block structure
      !! of the original DBCSR matrix from which the CSR matrix was created.

      PRIVATE
      INTEGER, DIMENSION(:), POINTER           :: csr_to_brd_ind => NULL(), &
                                                  brd_to_csr_ind => NULL()
         !! csr_to_brd_ind: csr_to_brd_ind(csr_ind) gives the location of the matrix element with
         !! CSR index csr_ind (its location in nzval_local) inside the data_area of the
         !! corresponding BRD matrix. If an element of the DBCSR matrix is treated as 0 in the
         !! CSR format, the index of this value is not in csr_to_brd_ind.
         !! brd_to_csr_ind: the inverse mapping of csr_to_brd_ind. If a given DBCSR index
         !! dbcsr_ind points to a zero element, then brd_to_csr_ind(dbcsr_ind) is -1.
      TYPE(dbcsr_type)                          :: brd_mat = dbcsr_type()
         !! block-row distributed (BRD) DBCSR matrix acting as an intermediate step in any
         !! conversion from and to DBCSR format.

      LOGICAL                                  :: has_dbcsr_block_data = .FALSE.
         !! whether the dbcsr_* fields below are defined
      INTEGER                                  :: dbcsr_nblkcols_total = -1, &
                                                  dbcsr_nblkrows_total = -1, &
                                                  dbcsr_nblks_local = -1
         !! dimensions of the original DBCSR matrix (not block-row distributed):
         !! dbcsr_nblkcols_total: total number of block columns
         !! dbcsr_nblkrows_total: total number of block rows
         !! dbcsr_nblks_local: local number of blocks
      INTEGER, DIMENSION(:), POINTER           :: dbcsr_row_p => NULL(), dbcsr_col_i => NULL(), &
                                                  dbcsr_row_blk_size => NULL(), dbcsr_col_blk_size => NULL()
         !! block structure of the original DBCSR matrix (not block-row distributed).
         !! Sizes as allocated in csr_create_template:
         !! dbcsr_row_p(dbcsr_nblkrows_total + 1), dbcsr_col_i(dbcsr_nblks_local),
         !! dbcsr_row_blk_size(dbcsr_nblkrows_total), dbcsr_col_blk_size(dbcsr_nblkcols_total)
   END TYPE

   TYPE csr_data_area_type
      !! Data area of CSR matrices: one pointer array per supported numeric type.
      !! csr_create_new allocates only the array that matches data_type.

      REAL(KIND=real_4), DIMENSION(:), POINTER      :: r_sp => Null()
         !! real, single precision data array
      REAL(KIND=real_8), DIMENSION(:), POINTER      :: r_dp => Null()
         !! real, double precision data array
      COMPLEX(KIND=real_4), DIMENSION(:), POINTER   :: c_sp => Null()
         !! complex, single precision data array
      COMPLEX(KIND=real_8), DIMENSION(:), POINTER   :: c_dp => Null()
         !! complex, double precision data array
      INTEGER                                       :: data_type = -1
         !! data type of CSR matrix (one of the dbcsr_type_* constants; -1 if unset)
   END TYPE

   TYPE csr_type
      !! Type for CSR matrices. Each process holds its local rows only.

      INTEGER                                  :: nrows_total = -1, ncols_total = -1, &
                                                  nze_local = -1, nrows_local = -1
         !! nrows_total: total number of rows
         !! ncols_total: total number of columns
         !! nze_local: local number of nonzero elements
         !! nrows_local: local number of rows
      TYPE(mp_comm_type)                       :: mp_group = mp_comm_type()
         !! message-passing group ID
      INTEGER(KIND=int_8)                      :: nze_total = -1_int_8
         !! total number of nonzero elements
      INTEGER, DIMENSION(:), POINTER           :: rowptr_local => NULL(), &
                                                  colind_local => NULL(), &
                                                  nzerow_local => NULL()
         !! rowptr_local: indices of elements inside nzval_local starting a row
         !! (allocated with nrows_local + 1 entries)
         !! colind_local: column indices of elements inside nzval_local
         !! (allocated with nze_local entries)
         !! nzerow_local: number of non-zero elements in each local row
         !! (allocated with nrows_local entries)
      TYPE(csr_data_area_type)                 :: nzval_local = csr_data_area_type()
         !! values of local non-zero elements, row-wise ordering.
      TYPE(csr_mapping_data)                   :: dbcsr_mapping = csr_mapping_data()
         !! mapping data relating indices of nzval_local to indices of a block-row distributed DBCSR matrix
      LOGICAL                                  :: has_mapping = .FALSE.
         !! whether dbcsr_mapping is defined
      LOGICAL                                  :: valid = .FALSE.
         !! whether essential data (excluding dbcsr_mapping) is completely allocated
      LOGICAL                                  :: has_indices = .FALSE.
         !! whether rowptr_local and colind_local are defined
   END TYPE csr_type

   TYPE csr_p_type
      !! Wrapper around a pointer to a CSR matrix
      !! (presumably so that arrays of CSR matrix pointers can be declared).
      TYPE(csr_type), POINTER                  :: csr_mat => NULL()
         !! the referenced CSR matrix; disassociated by default
   END TYPE csr_p_type

   ! Types and routines exported by this module.
   PUBLIC :: csr_type, csr_p_type, convert_csr_to_dbcsr, &
             csr_create_from_dbcsr, &
             csr_destroy, &
             convert_dbcsr_to_csr, &
             csr_create_new, csr_create_template, &
             csr_print_sparsity, dbcsr_to_csr_filter, &
             csr_write

   INTERFACE csr_create
      !! Generic CSR constructor: resolves to csr_create_new (explicit dimensions)
      !! or to csr_create_template (layout copied from an existing CSR matrix).
      !! The generic name is not listed in the PUBLIC statement above; only the
      !! two specific procedures are.
      MODULE PROCEDURE csr_create_new, csr_create_template
   END INTERFACE

CONTAINS

   SUBROUTINE csr_create_new(csr_mat, nrows_total, ncols_total, nze_total, &
                             nze_local, nrows_local, mp_group, data_type)
      !! Create a new CSR matrix and allocate all internal data (excluding dbcsr_mapping).
      !! The index arrays and the value array are only allocated, not filled:
      !! on return has_indices and has_mapping are .FALSE. and valid is .TRUE.
      !! Aborts if the requested sizes are inconsistent or data_type is not one of
      !! the four supported dbcsr_type_* constants.

      TYPE(csr_type), INTENT(OUT)                        :: csr_mat
         !! CSR matrix to return
      INTEGER, INTENT(IN)                                :: nrows_total, ncols_total
         !! total number of rows
         !! total number of columns
      INTEGER(KIND=int_8), INTENT(IN)                    :: nze_total
         !! total number of non-zero elements
      INTEGER, INTENT(IN)                                :: nze_local, nrows_local
         !! local number of non-zero elements
         !! local number of rows
      TYPE(mp_comm_type), INTENT(IN)                     :: mp_group
         !! message-passing group stored in the CSR matrix
      INTEGER, INTENT(IN), OPTIONAL                      :: data_type
         !! data type of the CSR matrix (default: dbcsr_type_real_default)

      CHARACTER(LEN=*), PARAMETER :: routineN = 'csr_create_new'
      INTEGER                                            :: handle

      CALL timeset(routineN, handle)

      ! Consistency checks between local and global sizes
      IF (nrows_total .LT. nrows_local) &
         DBCSR_ABORT("local number of rows must not exceed total number of rows")

      IF (nze_total .LT. nze_local) CALL dbcsr_abort(__LOCATION__, "local number of non-zero "// &
                                                     "elements must not exceed total number of non-zero elements")

      ! Products are formed in 64-bit integers to avoid overflow for large matrices
      IF (INT(nrows_total, kind=int_8)*INT(ncols_total, kind=int_8) .LT. nze_total) &
         DBCSR_ABORT("Total number of non-zero elements must not exceed total matrix size")

      IF (INT(nrows_local, kind=int_8)*INT(ncols_total, kind=int_8) .LT. nze_local) &
         DBCSR_ABORT("Local number of non-zero elements must not exceed local matrix size")

      csr_mat%ncols_total = ncols_total
      csr_mat%nrows_total = nrows_total
      csr_mat%nze_total = nze_total
      csr_mat%nze_local = nze_local
      csr_mat%nrows_local = nrows_local

      ! Index arrays: one column index per local non-zero, one row pointer per
      ! local row plus a terminating entry, one non-zero count per local row
      ALLOCATE (csr_mat%colind_local(nze_local))
      ALLOCATE (csr_mat%rowptr_local(nrows_local + 1))
      ALLOCATE (csr_mat%nzerow_local(nrows_local))

      IF (PRESENT(data_type)) THEN
         csr_mat%nzval_local%data_type = data_type
      ELSE
         csr_mat%nzval_local%data_type = dbcsr_type_real_default
      END IF

      ! Only the value array matching the data type is allocated
      SELECT CASE (csr_mat%nzval_local%data_type)
      CASE (dbcsr_type_real_4)
         ALLOCATE (csr_mat%nzval_local%r_sp(nze_local))
      CASE (dbcsr_type_real_8)
         ALLOCATE (csr_mat%nzval_local%r_dp(nze_local))
      CASE (dbcsr_type_complex_4)
         ALLOCATE (csr_mat%nzval_local%c_sp(nze_local))
      CASE (dbcsr_type_complex_8)
         ALLOCATE (csr_mat%nzval_local%c_dp(nze_local))
      CASE DEFAULT
         DBCSR_ABORT("Invalid matrix type")
      END SELECT

      csr_mat%mp_group = mp_group

      csr_mat%valid = .TRUE.
      csr_mat%has_mapping = .FALSE.
      csr_mat%has_indices = .FALSE.

      CALL timestop(handle)

   END SUBROUTINE csr_create_new

   SUBROUTINE csr_create_template(matrix_b, matrix_a)
      !! Create a new CSR matrix with the same layout as an existing CSR matrix.
      !! Sizes, data type and communicator are taken from the source; the CSR indices
      !! and the DBCSR mapping are copied when the source has them. The matrix values
      !! (nzval_local) are allocated but not copied.

      TYPE(csr_type), INTENT(OUT)                        :: matrix_b
         !! Target CSR matrix
      TYPE(csr_type), INTENT(IN)                         :: matrix_a
         !! Source CSR matrix

      CHARACTER(LEN=*), PARAMETER :: routineN = 'csr_create_template'

      INTEGER                                            :: handle
      TYPE(csr_mapping_data)                             :: src_map

      CALL timeset(routineN, handle)

      IF (.NOT. matrix_a%valid) &
         DBCSR_ABORT("Source CSR matrix must be created.")

      ! Allocate the target with the dimensions and data type of the source
      CALL csr_create_new(matrix_b, matrix_a%nrows_total, matrix_a%ncols_total, &
                          matrix_a%nze_total, matrix_a%nze_local, matrix_a%nrows_local, &
                          matrix_a%mp_group, matrix_a%nzval_local%data_type)

      matrix_b%has_indices = matrix_a%has_indices
      matrix_b%has_mapping = matrix_a%has_mapping
      matrix_b%mp_group = matrix_a%mp_group

      ! CSR index arrays (already allocated with matching sizes by csr_create_new)
      IF (matrix_a%has_indices) THEN
         matrix_b%colind_local = matrix_a%colind_local
         matrix_b%rowptr_local = matrix_a%rowptr_local
         matrix_b%nzerow_local = matrix_a%nzerow_local
      END IF

      IF (matrix_a%has_mapping) THEN
         src_map = matrix_a%dbcsr_mapping

         ! Index maps between CSR and block-row distributed storage
         ALLOCATE (matrix_b%dbcsr_mapping%csr_to_brd_ind(SIZE(src_map%csr_to_brd_ind)), &
                   matrix_b%dbcsr_mapping%brd_to_csr_ind(SIZE(src_map%brd_to_csr_ind)))
         matrix_b%dbcsr_mapping%csr_to_brd_ind = src_map%csr_to_brd_ind
         matrix_b%dbcsr_mapping%brd_to_csr_ind = src_map%brd_to_csr_ind

         ! Block structure of the original DBCSR matrix, if recorded in the source
         matrix_b%dbcsr_mapping%has_dbcsr_block_data = src_map%has_dbcsr_block_data
         IF (src_map%has_dbcsr_block_data) THEN
            matrix_b%dbcsr_mapping%dbcsr_nblkrows_total = src_map%dbcsr_nblkrows_total
            matrix_b%dbcsr_mapping%dbcsr_nblkcols_total = src_map%dbcsr_nblkcols_total
            matrix_b%dbcsr_mapping%dbcsr_nblks_local = src_map%dbcsr_nblks_local

            ALLOCATE (matrix_b%dbcsr_mapping%dbcsr_row_blk_size(src_map%dbcsr_nblkrows_total), &
                      matrix_b%dbcsr_mapping%dbcsr_col_blk_size(src_map%dbcsr_nblkcols_total), &
                      matrix_b%dbcsr_mapping%dbcsr_row_p(src_map%dbcsr_nblkrows_total + 1), &
                      matrix_b%dbcsr_mapping%dbcsr_col_i(src_map%dbcsr_nblks_local))

            matrix_b%dbcsr_mapping%dbcsr_row_blk_size = src_map%dbcsr_row_blk_size
            matrix_b%dbcsr_mapping%dbcsr_col_blk_size = src_map%dbcsr_col_blk_size
            matrix_b%dbcsr_mapping%dbcsr_row_p = src_map%dbcsr_row_p
            matrix_b%dbcsr_mapping%dbcsr_col_i = src_map%dbcsr_col_i
         END IF

         ! The intermediate block-row distributed matrix is copied via dbcsr_copy
         CALL dbcsr_copy(matrix_b%dbcsr_mapping%brd_mat, src_map%brd_mat)

         matrix_b%valid = .TRUE.

      END IF

      CALL timestop(handle)
   END SUBROUTINE csr_create_template

   SUBROUTINE csr_create_nzerow(csr_mat, nzerow)
      !! Fill a vector with the number of non-zero elements of each local row of a
      !! CSR matrix, computed as the difference of consecutive row pointers.

      TYPE(csr_type), INTENT(IN)                         :: csr_mat
         !! CSR matrix (must be valid; its rowptr_local is read)
      INTEGER, DIMENSION(:), INTENT(INOUT), POINTER      :: nzerow
         !! number of non-zero elements in each row; entries 1..nrows_local are overwritten

      CHARACTER(LEN=*), PARAMETER :: routineN = 'csr_create_nzerow'

      INTEGER                                            :: handle, nrows

      CALL timeset(routineN, handle)

      IF (.NOT. csr_mat%valid) &
         DBCSR_ABORT("CSR matrix must be created.")

      nrows = csr_mat%nrows_local
      ! row k holds the elements rowptr(k) .. rowptr(k+1)-1
      nzerow(1:nrows) = csr_mat%rowptr_local(2:nrows + 1) - csr_mat%rowptr_local(1:nrows)

      CALL timestop(handle)
   END SUBROUTINE csr_create_nzerow

   SUBROUTINE csr_destroy(csr_mat)
      !! Destroy a CSR matrix: free the index arrays, the value arrays and, if present,
      !! the DBCSR mapping (including its block-row distributed matrix), then reset the
      !! status flags. Aborts if the matrix is not valid.
      TYPE(csr_type), INTENT(INOUT)                      :: csr_mat
         !! CSR matrix to destroy

      CHARACTER(LEN=*), PARAMETER :: routineN = 'csr_destroy'
      INTEGER                                            :: handle

      CALL timeset(routineN, handle)

      IF (.NOT. csr_mat%valid) &
         DBCSR_ABORT("CSR matrix must be created before destroying it.")

      IF (ASSOCIATED(csr_mat%rowptr_local)) DEALLOCATE (csr_mat%rowptr_local)
      IF (ASSOCIATED(csr_mat%nzerow_local)) DEALLOCATE (csr_mat%nzerow_local)
      IF (ASSOCIATED(csr_mat%colind_local)) DEALLOCATE (csr_mat%colind_local)

      IF (csr_mat%has_mapping) THEN
         ! Free the mapping through csr_mat itself rather than through a local copy:
         ! DEALLOCATE only disassociates the pointer it is given, so freeing via a copy
         ! would leave the components stored in csr_mat pointing at released memory.
         IF (ASSOCIATED(csr_mat%dbcsr_mapping%csr_to_brd_ind)) &
            DEALLOCATE (csr_mat%dbcsr_mapping%csr_to_brd_ind)
         IF (ASSOCIATED(csr_mat%dbcsr_mapping%brd_to_csr_ind)) &
            DEALLOCATE (csr_mat%dbcsr_mapping%brd_to_csr_ind)
         IF (ASSOCIATED(csr_mat%dbcsr_mapping%dbcsr_row_blk_size)) &
            DEALLOCATE (csr_mat%dbcsr_mapping%dbcsr_row_blk_size)
         IF (ASSOCIATED(csr_mat%dbcsr_mapping%dbcsr_col_blk_size)) &
            DEALLOCATE (csr_mat%dbcsr_mapping%dbcsr_col_blk_size)
         IF (ASSOCIATED(csr_mat%dbcsr_mapping%dbcsr_row_p)) &
            DEALLOCATE (csr_mat%dbcsr_mapping%dbcsr_row_p)
         IF (ASSOCIATED(csr_mat%dbcsr_mapping%dbcsr_col_i)) &
            DEALLOCATE (csr_mat%dbcsr_mapping%dbcsr_col_i)

         ! Release the stored block-row distributed matrix in place
         CALL dbcsr_release(csr_mat%dbcsr_mapping%brd_mat)
      END IF

      ! At most one of these is associated (see csr_create_new); check all four
      IF (ASSOCIATED(csr_mat%nzval_local%r_dp)) &
         DEALLOCATE (csr_mat%nzval_local%r_dp)
      IF (ASSOCIATED(csr_mat%nzval_local%r_sp)) &
         DEALLOCATE (csr_mat%nzval_local%r_sp)
      IF (ASSOCIATED(csr_mat%nzval_local%c_sp)) &
         DEALLOCATE (csr_mat%nzval_local%c_sp)
      IF (ASSOCIATED(csr_mat%nzval_local%c_dp)) &
         DEALLOCATE (csr_mat%nzval_local%c_dp)

      csr_mat%has_mapping = .FALSE.
      csr_mat%valid = .FALSE.
      csr_mat%dbcsr_mapping%has_dbcsr_block_data = .FALSE.
      csr_mat%has_indices = .FALSE.
      csr_mat%nzval_local%data_type = -1

      CALL timestop(handle)
   END SUBROUTINE csr_destroy

   SUBROUTINE csr_create_from_brd(brd_mat, csr_mat, csr_sparsity_brd)
      !! Allocate the internals of a CSR matrix using data from a block-row
      !! distributed DBCSR matrix, and attach the index mapping between both
      !! formats. Aborts if the column distribution of brd_mat is not all zero.

      TYPE(dbcsr_type), INTENT(IN)                       :: brd_mat
         !! block-row-distributed DBCSR matrix
      TYPE(csr_type), INTENT(OUT)                        :: csr_mat
         !! CSR matrix
      TYPE(dbcsr_type), INTENT(IN)                       :: csr_sparsity_brd
         !! BRD matrix representing sparsity pattern of CSR matrix

      CHARACTER(LEN=*), PARAMETER :: routineN = 'csr_create_from_brd'

      INTEGER                                            :: handle, mat_data_type, ncols_full, &
                                                            nrows_full, nrows_mine, nze_mine
      INTEGER(KIND=int_8)                                :: nze_all
      INTEGER, DIMENSION(:), POINTER                     :: brd2csr, col_dist, csr2brd
      TYPE(dbcsr_distribution_obj)                       :: dist
      TYPE(mp_comm_type)                                 :: comm

      CALL timeset(routineN, handle)
      NULLIFY (csr2brd, brd2csr, col_dist)

      dist = dbcsr_distribution(brd_mat)
      col_dist => dbcsr_distribution_col_dist(dist)
      comm = dbcsr_mp_group(dbcsr_distribution_mp(dist))

      ! every block column must be assigned to process column 0
      IF (.NOT. ALL(col_dist .EQ. 0)) &
         DBCSR_ABORT("DBCSR matrix not block-row distributed.")

      ! Index maps in both directions and the local count of CSR non-zeros
      CALL csr_get_dbcsr_mapping(brd_mat, csr2brd, brd2csr, nze_mine, &
                                 csr_sparsity_brd)

      mat_data_type = dbcsr_get_data_type(brd_mat)
      nrows_mine = dbcsr_nfullrows_local(brd_mat)
      CALL dbcsr_get_info(brd_mat, nfullrows_total=nrows_full, &
                          nfullcols_total=ncols_full)

      ! Global number of non-zeros: sum of the local counts over the group
      nze_all = INT(nze_mine, KIND=int_8)
      CALL mp_sum(nze_all, comm)

      CALL csr_create_new(csr_mat, nrows_total=nrows_full, ncols_total=ncols_full, &
                          nze_total=nze_all, nze_local=nze_mine, nrows_local=nrows_mine, &
                          mp_group=comm, data_type=mat_data_type)

      ! Hand over the freshly allocated maps and remember the BRD matrix
      csr_mat%dbcsr_mapping%csr_to_brd_ind => csr2brd
      csr_mat%dbcsr_mapping%brd_to_csr_ind => brd2csr
      csr_mat%dbcsr_mapping%brd_mat = brd_mat
      csr_mat%has_mapping = .TRUE.

      CALL timestop(handle)
   END SUBROUTINE csr_create_from_brd

   SUBROUTINE csr_get_dbcsr_mapping(brd_mat, dbcsr_index, csr_index, csr_nze_local, &
                                    csr_sparsity_brd)
      !! create the mapping information between a block-row distributed DBCSR
      !! matrix and the corresponding CSR matrix.
      !!
      !! Works in three steps: (1) count the full columns covered by the blocks of each
      !! block row, (2) enumerate all stored elements and pair their position in the block
      !! data (dbcsr index) with their row-wise position (csr index), (3) drop the
      !! elements whose entry in csr_sparsity_brd is zero and renumber the CSR indices.
      !! Both output arrays are allocated here; the caller takes ownership.

      TYPE(dbcsr_type), INTENT(IN)                       :: brd_mat
         !! the block-row distributed DBCSR matrix
      INTEGER, DIMENSION(:), INTENT(OUT), POINTER        :: dbcsr_index, csr_index
         !! dbcsr_index: csr to dbcsr index mapping (length = number of kept elements)
         !! csr_index: dbcsr to csr index mapping (length = nze of brd_mat; -1 for dropped elements)
      INTEGER, INTENT(OUT)                               :: csr_nze_local
         !! number of local non-zero elements
      TYPE(dbcsr_type), INTENT(IN)                       :: csr_sparsity_brd
         !! sparsity of CSR matrix represented in BRD format
         !! NOTE(review): its data area is indexed with the dbcsr indices of brd_mat, so it
         !! is assumed to have the same block layout as brd_mat - confirm against callers

      CHARACTER(LEN=*), PARAMETER :: routineN = 'csr_get_dbcsr_mapping'

      INTEGER :: blk, blkcol, blkrow, col_blk_size, csr_ind, data_type, dbcsr_ind, el_sum, &
                 fullcol_sum_blkrow, handle, l, m, n, nblkrows_total, nze, prev_blk, prev_blkrow, &
                 prev_row_blk_size, row_blk_offset, row_blk_size
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: csr_nze, nfullcol_blkrow
      INTEGER, DIMENSION(:), POINTER                     :: dbcsr_index_nozeroes
      LOGICAL                                            :: tr
      TYPE(dbcsr_iterator)                               :: iter

      CALL timeset(routineN, handle)

      m = 0
      dbcsr_ind = 0
      fullcol_sum_blkrow = 0
      NULLIFY (dbcsr_index, csr_index)

      CALL dbcsr_get_info(brd_mat, nblkrows_total=nblkrows_total)
      nze = dbcsr_get_nze(brd_mat)

      ALLOCATE (nfullcol_blkrow(nblkrows_total))
      ALLOCATE (dbcsr_index(nze))
      ALLOCATE (csr_index(nze))

      ! Step 1: per block row, sum the column sizes of its blocks
      CALL dbcsr_iterator_start(iter, brd_mat, read_only=.TRUE.)
      nfullcol_blkrow = 0 ! number of non-zero full columns in each block row
      prev_blk = 0

      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, blkrow, blkcol, blk, transposed=tr, &
                                        col_size=col_blk_size)

         ! the index arithmetic below relies on blocks arriving with consecutive block numbers
         IF (blk /= prev_blk + 1) &
            DBCSR_ABORT("iterator is required to traverse the blocks in a row-wise fashion")

         prev_blk = blk

         nfullcol_blkrow(blkrow) = nfullcol_blkrow(blkrow) + col_blk_size
         IF (tr) &
            DBCSR_ABORT("DBCSR block data must not be transposed")
      END DO
      CALL dbcsr_iterator_stop(iter)

      el_sum = 0 ! number of elements above current block row

      prev_blkrow = 0 ! store number and size of previous block row
      prev_row_blk_size = 0

      ! Step 2: pair every stored element's dbcsr index with its row-wise CSR index
      CALL dbcsr_iterator_start(iter, brd_mat, read_only=.TRUE.)

      DO WHILE (dbcsr_iterator_blocks_left(iter))

         CALL dbcsr_iterator_next_block(iter, blkrow, blkcol, blk, transposed=tr, &
                                        row_size=row_blk_size, col_size=col_blk_size, row_offset=row_blk_offset)

         IF (blkrow .GT. prev_blkrow) THEN ! new block row
            IF (prev_blkrow .GT. 0) THEN
               ! the previous block row contributed (its full columns) x (its rows) elements
               el_sum = el_sum + nfullcol_blkrow(prev_blkrow)*prev_row_blk_size
            END IF

            ! number of non-zero full columns on the left of current block:
            fullcol_sum_blkrow = 0

            dbcsr_ind = el_sum
         END IF
         ! dbcsr_ind runs through the block with the row index m fastest, while csr_ind
         ! places the same element at row m, column offset n of the block row:
         ! (m-1) complete rows of nfullcol_blkrow(blkrow) elements precede it, plus the
         ! columns of the blocks to its left.
         DO n = 1, col_blk_size !nr of columns
            DO m = 1, row_blk_size !nr of rows
               dbcsr_ind = dbcsr_ind + 1
               csr_ind = (m - 1)*nfullcol_blkrow(blkrow) + fullcol_sum_blkrow + n + el_sum
               dbcsr_index(csr_ind) = dbcsr_ind
               csr_index(dbcsr_ind) = csr_ind
            END DO
         END DO
         fullcol_sum_blkrow = fullcol_sum_blkrow + col_blk_size
         prev_blkrow = blkrow
         prev_row_blk_size = row_blk_size
      END DO
      CALL dbcsr_iterator_stop(iter)

      ! Step 3:
      ! remove BRD zero elements from CSR format
      data_type = dbcsr_get_data_type(csr_sparsity_brd)
      ALLOCATE (csr_nze(nze))

      ! the sparsity values are truncated to integers; 0 marks an element to drop
      SELECT CASE (data_type)
      CASE (dbcsr_type_real_4)
         csr_nze(:) = INT(csr_sparsity_brd%data_area%d%r_sp(1:nze))
      CASE (dbcsr_type_real_8)
         csr_nze(:) = INT(csr_sparsity_brd%data_area%d%r_dp(1:nze))
      CASE DEFAULT
         DBCSR_ABORT("CSR sparsity matrix must have a real datatype")
      END SELECT

      IF (ANY(csr_nze .EQ. 0)) THEN
         ! NOTE(review): SUM(csr_nze) is used as the number of kept elements, which is
         ! only correct if every non-zero sparsity entry equals 1 - confirm
         ALLOCATE (dbcsr_index_nozeroes(SUM(csr_nze)))
         m = 0 ! csr index if zeroes are excluded from CSR data
         DO l = 1, nze ! csr index if zeroes are included in CSR data
            IF (csr_nze(dbcsr_index(l)) .EQ. 0) THEN
               csr_index(dbcsr_index(l)) = -1
            ELSE
               m = m + 1
               dbcsr_index_nozeroes(m) = dbcsr_index(l)
               csr_index(dbcsr_index(l)) = m
            END IF
         END DO
         ! replace the full map by the compacted one
         DEALLOCATE (dbcsr_index)
         dbcsr_index => dbcsr_index_nozeroes
      END IF

      ! m holds the number of kept elements only if the compaction above ran
      IF (ANY(csr_nze .EQ. 0)) THEN
         csr_nze_local = m
      ELSE
         csr_nze_local = nze
      END IF

      CALL timestop(handle)
   END SUBROUTINE csr_get_dbcsr_mapping

   SUBROUTINE convert_csr_to_brd(brd_mat, csr_mat)
      !! Copies data from a CSR matrix to a block-row distributed DBCSR matrix.
      !! The DBCSR matrix must have a block structure consistent with the CSR matrix.
      !! The data area of brd_mat is released and re-created with nze entries, set to
      !! zero, and then every CSR value k is written to position csr_to_brd_ind(k).
      !! The data type of brd_mat selects which value array of csr_mat is read.
      !! Aborts if that data type is not one of the four supported dbcsr_type_* constants.

      TYPE(dbcsr_type), INTENT(INOUT)                    :: brd_mat
         !! block-row distributed DBCSR matrix
      TYPE(csr_type), INTENT(IN)                         :: csr_mat
         !! CSR matrix (its dbcsr_mapping%csr_to_brd_ind must be defined)

      CHARACTER(LEN=*), PARAMETER :: routineN = 'convert_csr_to_brd'

      INTEGER                                            :: data_type, handle, ind, k, nze

      CALL timeset(routineN, handle)

      data_type = dbcsr_get_data_type(brd_mat)
      nze = dbcsr_get_nze(brd_mat)
      CALL dbcsr_data_release(brd_mat%data_area)
      CALL dbcsr_data_new(brd_mat%data_area, data_type, nze)

      ! The data type does not change inside the copy loop, so dispatch once per
      ! call: zero the data area first (elements absent from the CSR matrix stay
      ! zero), then scatter the CSR values.
      SELECT CASE (data_type)
      CASE (dbcsr_type_real_4)
         brd_mat%data_area%d%r_sp(1:nze) = 0.0_sp
         DO k = 1, csr_mat%nze_local
            ind = csr_mat%dbcsr_mapping%csr_to_brd_ind(k)
            brd_mat%data_area%d%r_sp(ind) = csr_mat%nzval_local%r_sp(k)
         END DO
      CASE (dbcsr_type_real_8)
         brd_mat%data_area%d%r_dp(1:nze) = 0.0_dp
         DO k = 1, csr_mat%nze_local
            ind = csr_mat%dbcsr_mapping%csr_to_brd_ind(k)
            brd_mat%data_area%d%r_dp(ind) = csr_mat%nzval_local%r_dp(k)
         END DO
      CASE (dbcsr_type_complex_4)
         brd_mat%data_area%d%c_sp(1:nze) = 0.0_sp
         DO k = 1, csr_mat%nze_local
            ind = csr_mat%dbcsr_mapping%csr_to_brd_ind(k)
            brd_mat%data_area%d%c_sp(ind) = csr_mat%nzval_local%c_sp(k)
         END DO
      CASE (dbcsr_type_complex_8)
         brd_mat%data_area%d%c_dp(1:nze) = 0.0_dp
         DO k = 1, csr_mat%nze_local
            ind = csr_mat%dbcsr_mapping%csr_to_brd_ind(k)
            brd_mat%data_area%d%c_dp(ind) = csr_mat%nzval_local%c_dp(k)
         END DO
      CASE DEFAULT
         ! previously an unsupported type fell through silently and copied nothing
         DBCSR_ABORT("Invalid matrix type")
      END SELECT

      CALL timestop(handle)
   END SUBROUTINE convert_csr_to_brd

   SUBROUTINE convert_brd_to_csr(brd_mat, csr_mat)
      !! Convert a block-row distributed DBCSR matrix to a CSR matrix
      !! The DBCSR matrix must have a block structure consistent with the CSR matrix.
      !!
      !! Copies every element that has a positive CSR index in the mapping into
      !! nzval_local. If the CSR matrix has no indices yet, colind_local and rowptr_local
      !! are computed as well and has_indices is set; otherwise the existing indices are
      !! reused. nzerow_local is recomputed from rowptr_local on every call.
      !! Aborts if csr_mat has no DBCSR mapping or if a block is stored transposed.

      TYPE(dbcsr_type), INTENT(IN)                       :: brd_mat
         !! block-row distributed DBCSR matrix
      TYPE(csr_type), INTENT(INOUT)                      :: csr_mat
         !! CSR matrix

      CHARACTER(LEN=*), PARAMETER :: routineN = 'convert_brd_to_csr'

      INTEGER :: blk, blkcol, blkrow, col_blk_offset, col_blk_size, csr_ind, data_type, dbcsr_ind, &
                 el_sum, handle, ind_blk_data, k, local_row_ind, m, n, nblkrows_total, node_row_offset, &
                 prev_blkrow, prev_row_blk_size, row_blk_offset, row_blk_size
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: nfullcol_blkrow
      INTEGER, DIMENSION(:), POINTER                     :: colind, csr_index, dbcsr_index, nzerow, &
                                                            rowptr
      LOGICAL                                            :: new_ind, tr
      TYPE(dbcsr_data_obj)                               :: block
      TYPE(dbcsr_iterator)                               :: iter

      CALL timeset(routineN, handle)
      local_row_ind = 0
      dbcsr_ind = 0
      node_row_offset = 0
      NULLIFY (rowptr, colind, dbcsr_index, csr_index)

      ! dbcsr_index is only pointed at the map here; it is not read in this routine
      dbcsr_index => csr_mat%dbcsr_mapping%csr_to_brd_ind
      csr_index => csr_mat%dbcsr_mapping%brd_to_csr_ind

      ! CSR indices are not recalculated if indices are already defined
      new_ind = .NOT. (csr_mat%has_indices)

      IF (.NOT. csr_mat%has_mapping) &
         DBCSR_ABORT("DBCSR mapping of CSR matrix must be defined")

      ! Calculate mapping between CSR matrix and DBCSR matrix if not yet defined
      !IF (.NOT. csr_mat%has_mapping ) THEN
      !  CALL csr_get_dbcsr_mapping (brd_mat, dbcsr_index, csr_index, nze)
      !ENDIF

      CALL dbcsr_get_info(brd_mat, nblkrows_total=nblkrows_total)
      ALLOCATE (nfullcol_blkrow(nblkrows_total))

      ! iteration over blocks without touching data,
      ! in order to get number of non-zero full columns in each block row
      CALL dbcsr_iterator_start(iter, brd_mat, read_only=.TRUE.)
      blkrow = 0
      nfullcol_blkrow = 0 ! number of non-zero full columns in each block row
      data_type = dbcsr_get_data_type(brd_mat)

      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, blkrow, blkcol, blk, col_size=col_blk_size, &
                                        row_offset=row_blk_offset)
         nfullcol_blkrow(blkrow) = nfullcol_blkrow(blkrow) + col_blk_size
         ! the row offset of block number 1 is used as the offset of this process' rows
         IF (blk .EQ. 1) THEN
            node_row_offset = row_blk_offset
         END IF
      END DO

      CALL dbcsr_iterator_stop(iter)

      ! Copy data from BRD matrix to CSR matrix and calculate CSR indices
      prev_blkrow = 0
      prev_row_blk_size = 0
      el_sum = 0 ! number of elements above current block row
      colind => csr_mat%colind_local
      rowptr => csr_mat%rowptr_local
      nzerow => csr_mat%nzerow_local
      ! block is a data object created without a size; the iterator call below is
      ! handed this object for each block and its d%... arrays are read afterwards
      CALL dbcsr_data_init(block)
      CALL dbcsr_data_new(block, data_type)

      CALL dbcsr_iterator_start(iter, brd_mat, read_only=.TRUE.)

      IF (new_ind) rowptr(:) = 0 ! initialize to 0 in order to check which rows are 0 at a later time
      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, blkrow, blkcol, block, tr, &
                                        col_size=col_blk_size, row_size=row_blk_size, row_offset=row_blk_offset, &
                                        col_offset=col_blk_offset)

         IF (tr) &
            DBCSR_ABORT("DBCSR block data must not be transposed")

         IF (blkrow > prev_blkrow) THEN ! new block row
            local_row_ind = row_blk_offset - node_row_offset ! actually: local row index - 1
            IF (prev_blkrow .GT. 0) THEN
               el_sum = el_sum + nfullcol_blkrow(prev_blkrow)*prev_row_blk_size
            END IF
            ! dbcsr_ind counts stored elements in the same order as csr_get_dbcsr_mapping
            dbcsr_ind = el_sum
         END IF
         DO n = 1, col_blk_size !nr of columns
            DO m = 1, row_blk_size !nr of rows
               dbcsr_ind = dbcsr_ind + 1
               csr_ind = csr_index(dbcsr_ind) ! get CSR index for current element
               IF (csr_ind .GT. 0) THEN ! is non-zero element if csr_ind > 0
                  IF (new_ind) THEN
                     ! Calculate CSR column index
                     colind(csr_ind) = col_blk_offset + n - 1
                     ! Calculate CSR row pointer
                     ! (not accounting for zero elements inside non-zero blocks)
                     ! A row pointer is set only once (while still <= 0). Its current
                     ! non-positive value carries the -1 corrections accumulated from
                     ! zero elements in earlier rows, hence it is added to the new value.
                     IF (rowptr(local_row_ind + m) .LE. 0) rowptr(local_row_ind + m) = &
                        rowptr(local_row_ind + m) + el_sum + 1 + nfullcol_blkrow(blkrow)*(m - 1)
                  END IF
                  ind_blk_data = (m + row_blk_size*(n - 1)) ! index of data inside DBCSR blocks
                  SELECT CASE (csr_mat%nzval_local%data_type)
                  CASE (dbcsr_type_real_4)
                     csr_mat%nzval_local%r_sp(csr_ind) = block%d%r_sp(ind_blk_data)
                  CASE (dbcsr_type_real_8)
                     csr_mat%nzval_local%r_dp(csr_ind) = block%d%r_dp(ind_blk_data)
                  CASE (dbcsr_type_complex_4)
                     csr_mat%nzval_local%c_sp(csr_ind) = block%d%c_sp(ind_blk_data)
                  CASE (dbcsr_type_complex_8)
                     csr_mat%nzval_local%c_dp(csr_ind) = block%d%c_dp(ind_blk_data)
                  END SELECT
               ELSE ! is zero element if ind = -1
                  ! CSR row pointer has to be corrected if element is zero
                  ! (subtract 1 from all subsequent row pointers)
                  IF (new_ind) rowptr(local_row_ind + m + 1:) = rowptr(local_row_ind + m + 1:) - 1
               END IF
            END DO
         END DO
         prev_blkrow = blkrow
         prev_row_blk_size = row_blk_size
      END DO

      IF (new_ind) THEN
         ! Repeat previous row pointer for row pointers to zero rows
         ! rowptr(1) is forced to 1 first, so the loop never reads rowptr(0)
         IF (csr_mat%nrows_local .GT. 0) rowptr(1) = 1
         DO k = 1, csr_mat%nrows_local
            IF (rowptr(k) .LE. 0) rowptr(k) = rowptr(k - 1)
         END DO

         ! terminating entry: one past the last local non-zero
         rowptr(csr_mat%nrows_local + 1) = csr_mat%nze_local + 1
      END IF

      ! nzerow points at csr_mat%nzerow_local, so this fills the matrix' own array
      CALL csr_create_nzerow(csr_mat, nzerow)

      CALL dbcsr_iterator_stop(iter)
      CALL dbcsr_data_clear_pointer(block)
      CALL dbcsr_data_release(block)

      IF (new_ind) csr_mat%has_indices = .TRUE.

      CALL timestop(handle)
   END SUBROUTINE convert_brd_to_csr

   SUBROUTINE csr_create_from_dbcsr(dbcsr_mat, csr_mat, dist_format, csr_sparsity, numnodes)
      !! create CSR matrix including dbcsr_mapping from arbitrary DBCSR matrix
      !! in order to prepare conversion.
      !! Aborts if dbcsr_mat has no valid index, if dist_format is not one of the
      !! three supported distribution formats, if the sparsity matrix and
      !! dbcsr_mat differ in block structure, or if numnodes is not in the range
      !! 1 .. number of nodes of the DBCSR distribution.

      TYPE(dbcsr_type), INTENT(IN)                       :: dbcsr_mat
         !! DBCSR matrix the CSR matrix is created for
      TYPE(csr_type), INTENT(OUT)                        :: csr_mat
         !! CSR matrix to be created (including the stored DBCSR block data)
      INTEGER, INTENT(IN)                                :: dist_format
         !! how to distribute CSR rows over processes: csr_dbcsr_blkrow_dist: the number of rows per process is adapted to the row
         !! block sizes in the DBCSR format such that blocks are not split over different processes. csr_eqrow_ceil_dist: each
         !! process holds ceiling(N/P) CSR rows. csr_eqrow_floor_dist: each process holds floor(N/P) CSR rows.
      TYPE(dbcsr_type), INTENT(IN), OPTIONAL             :: csr_sparsity
         !! DBCSR matrix containing 0 and 1, representing CSR sparsity pattern 1: non-zero element 0: zero element (not present in
         !! CSR format) Note: matrix must be of data_type dbcsr_type_real_4 or dbcsr_type_real_8 (integer types not supported)
      INTEGER, INTENT(IN), OPTIONAL                      :: numnodes
         !! number of nodes to use for distributing CSR matrix (optional, default is number of nodes used for DBCSR matrix)

      CHARACTER(LEN=*), PARAMETER :: routineN = 'csr_create_from_dbcsr'

      INTEGER                                            :: dbcsr_numnodes, handle, nblkcols_total, &
                                                            nblkrows_total, nblks_local, num_p
      LOGICAL                                            :: equal_dist, floor_dist
      TYPE(dbcsr_type)                                   :: brd_mat, csr_sparsity_brd, &
                                                            csr_sparsity_nosym, dbcsr_mat_nosym

      CALL timeset(routineN, handle)

      IF (.NOT. dbcsr_valid_index(dbcsr_mat)) &
         DBCSR_ABORT("Invalid DBCSR matrix")

      ! Translate the distribution format into the two flags expected by dbcsr_create_brd
      SELECT CASE (dist_format)
      CASE (csr_dbcsr_blkrow_dist)
         equal_dist = .FALSE.
         floor_dist = .FALSE.
      CASE (csr_eqrow_ceil_dist)
         equal_dist = .TRUE.
         floor_dist = .FALSE.
      CASE (csr_eqrow_floor_dist)
         equal_dist = .TRUE.
         floor_dist = .TRUE.
      CASE DEFAULT
         ! without this branch both flags would be used undefined below
         DBCSR_ABORT("Unknown CSR distribution format (dist_format)")
      END SELECT

      ! Conversion does not support matrices in symmetric format, therefore desymmetrize
      IF (dbcsr_has_symmetry(dbcsr_mat)) THEN
         CALL dbcsr_desymmetrize_deep(dbcsr_mat, dbcsr_mat_nosym, untransposed_data=.TRUE.)
      ELSE
         CALL dbcsr_copy(dbcsr_mat_nosym, dbcsr_mat)
      END IF

      ! Sparsity pattern: either the (desymmetrized) user-supplied matrix, or a
      ! matrix with the block structure of dbcsr_mat and all elements set to 1
      IF (PRESENT(csr_sparsity)) THEN
         IF (dbcsr_has_symmetry(csr_sparsity)) THEN
            CALL dbcsr_desymmetrize_deep(csr_sparsity, csr_sparsity_nosym, &
                                         untransposed_data=.TRUE.)
         ELSE
            CALL dbcsr_copy(csr_sparsity_nosym, csr_sparsity)
         END IF
      ELSE
         CALL dbcsr_create(csr_sparsity_nosym, &
                           template=dbcsr_mat_nosym, &
                           name="CSR sparsity matrix", &
                           data_type=dbcsr_type_real_8)
         CALL dbcsr_copy(csr_sparsity_nosym, dbcsr_mat_nosym)
         CALL dbcsr_set(csr_sparsity_nosym, 1.0_dp)
      END IF

      IF (.NOT. dbcsr_has_same_block_structure(dbcsr_mat_nosym, csr_sparsity_nosym)) &
         DBCSR_ABORT("csr_sparsity and dbcsr_mat have different sparsity pattern")

      ! Number of nodes over which the CSR rows are distributed
      dbcsr_numnodes = dbcsr_mp_numnodes(dbcsr_distribution_mp(dbcsr_distribution(dbcsr_mat)))
      IF (PRESENT(numnodes)) THEN
         ! numnodes is used as a divisor when the row distribution is computed
         IF (numnodes .LT. 1) &
            DBCSR_ABORT("Number of nodes used for CSR matrix must be positive")
         IF (numnodes .GT. dbcsr_numnodes) &
            CALL dbcsr_abort(__LOCATION__, "Number of nodes used for CSR matrix "// &
                             "must not exceed total number of nodes")

         num_p = numnodes
      ELSE
         num_p = dbcsr_numnodes
      END IF

      ! Block-row distributed (BRD) versions of the matrix and of its sparsity pattern
      CALL dbcsr_create_brd(dbcsr_mat_nosym, brd_mat, equal_dist, floor_dist, &
                            num_p)
      CALL dbcsr_create_brd(csr_sparsity_nosym, csr_sparsity_brd, equal_dist, floor_dist, &
                            num_p)

      ! Create CSR matrix from BRD matrix
      ! NOTE(review): brd_mat is not released in this routine; presumably
      ! csr_create_from_brd keeps it inside csr_mat%dbcsr_mapping - confirm
      CALL csr_create_from_brd(brd_mat, csr_mat, csr_sparsity_brd)

      ! Store DBCSR block data inside CSR matrix
      ! (otherwise, this data is lost when converting from DBCSR to CSR)
      nblks_local = dbcsr_get_num_blocks(dbcsr_mat_nosym)
      nblkrows_total = dbcsr_nblkrows_total(dbcsr_mat_nosym)
      nblkcols_total = dbcsr_nblkcols_total(dbcsr_mat_nosym)

      csr_mat%dbcsr_mapping%dbcsr_nblkcols_total = nblkcols_total
      csr_mat%dbcsr_mapping%dbcsr_nblkrows_total = nblkrows_total
      csr_mat%dbcsr_mapping%dbcsr_nblks_local = nblks_local
      ALLOCATE (csr_mat%dbcsr_mapping%dbcsr_row_p(nblkrows_total + 1))
      csr_mat%dbcsr_mapping%dbcsr_row_p = dbcsr_mat_nosym%row_p
      ALLOCATE (csr_mat%dbcsr_mapping%dbcsr_col_i(nblks_local))
      csr_mat%dbcsr_mapping%dbcsr_col_i = dbcsr_mat_nosym%col_i

      ALLOCATE (csr_mat%dbcsr_mapping%dbcsr_row_blk_size(nblkrows_total))
      ALLOCATE (csr_mat%dbcsr_mapping%dbcsr_col_blk_size(nblkcols_total))

      csr_mat%dbcsr_mapping%dbcsr_row_blk_size = dbcsr_row_block_sizes(dbcsr_mat_nosym)
      csr_mat%dbcsr_mapping%dbcsr_col_blk_size = dbcsr_col_block_sizes(dbcsr_mat_nosym)

      csr_mat%dbcsr_mapping%has_dbcsr_block_data = .TRUE.

      ! Temporary matrices are no longer needed
      CALL dbcsr_release(dbcsr_mat_nosym)
      CALL dbcsr_release(csr_sparsity_nosym)
      CALL dbcsr_release(csr_sparsity_brd)

      CALL timestop(handle)
   END SUBROUTINE csr_create_from_dbcsr

   FUNCTION dbcsr_has_same_block_structure(matrix_a, matrix_b) RESULT(is_equal)
      !! Helper function to assert that two DBCSR matrices have the same block
      !! structure and same sparsity pattern.
      !! Compared are the total numbers of block rows and block columns, the
      !! number of blocks (nblks), the index arrays row_p and col_i, and the row
      !! and column block sizes. The function returns at the first difference, so
      !! index arrays are only compared once the corresponding counts agree.

      TYPE(dbcsr_type), INTENT(IN)                       :: matrix_a, matrix_b
         !! matrices to compare
      LOGICAL                                            :: is_equal
         !! whether matrix_a and matrix_b have the same block structure

      is_equal = .FALSE.

      ! Scalar dimensions first
      IF (dbcsr_nblkcols_total(matrix_a) .NE. dbcsr_nblkcols_total(matrix_b)) RETURN
      IF (dbcsr_nblkrows_total(matrix_a) .NE. dbcsr_nblkrows_total(matrix_b)) RETURN
      IF ((matrix_a%nblks) .NE. (matrix_b%nblks)) RETURN

      ! Sparsity pattern (block row pointers and block column indices)
      IF (ANY(matrix_a%row_p .NE. matrix_b%row_p)) RETURN
      IF (ANY(matrix_a%col_i .NE. matrix_b%col_i)) RETURN

      ! Block sizes in both dimensions
      IF (ANY(dbcsr_row_block_sizes(matrix_a) .NE. &
              dbcsr_row_block_sizes(matrix_b))) RETURN
      IF (ANY(dbcsr_col_block_sizes(matrix_a) .NE. &
              dbcsr_col_block_sizes(matrix_b))) RETURN

      is_equal = .TRUE.

   END FUNCTION dbcsr_has_same_block_structure

   SUBROUTINE csr_assert_consistency_with_dbcsr(csr_mat, dbcsr_mat)
      !! Verify that the DBCSR block data recorded in the mapping of a CSR matrix
      !! agrees with a given DBCSR matrix. The first mismatch aborts; if the CSR
      !! matrix holds no block data, only a warning is emitted.

      TYPE(csr_type), INTENT(IN)                         :: csr_mat
         !! CSR matrix whose dbcsr_mapping is inspected
      TYPE(dbcsr_type), INTENT(IN)                       :: dbcsr_mat
         !! DBCSR matrix to compare against

      CHARACTER(LEN=*), PARAMETER :: routineN = 'csr_assert_consistency_with_dbcsr'

      INTEGER                                            :: handle
      TYPE(csr_mapping_data)                             :: mapping

      CALL timeset(routineN, handle)

      mapping = csr_mat%dbcsr_mapping

      IF (.NOT. mapping%has_dbcsr_block_data) THEN
         ! nothing to compare against
         CALL dbcsr_warn(__LOCATION__, "Can not assert consistency of the matrices "// &
                         "as no block data stored in CSR matrix.")
      ELSE
         ! total numbers of block columns / block rows
         IF (mapping%dbcsr_nblkcols_total /= dbcsr_nblkcols_total(dbcsr_mat)) THEN
            CALL dbcsr_abort(__LOCATION__, &
                             "field nblkcols_total of DBCSR matrix not consistent with CSR matrix")
         END IF
         IF (mapping%dbcsr_nblkrows_total /= dbcsr_nblkrows_total(dbcsr_mat)) THEN
            CALL dbcsr_abort(__LOCATION__, &
                             "field nblkrows_total of DBCSR matrix not consistent with CSR matrix")
         END IF

         ! number of blocks and block index arrays
         IF (mapping%dbcsr_nblks_local /= dbcsr_mat%nblks) THEN
            CALL dbcsr_abort(__LOCATION__, &
                             "field nblks of DBCSR matrix not consistent with CSR matrix")
         END IF
         IF (ANY(mapping%dbcsr_row_p /= dbcsr_mat%row_p)) THEN
            CALL dbcsr_abort(__LOCATION__, &
                             "field row_p of DBCSR matrix not consistent with CSR matrix")
         END IF
         IF (ANY(mapping%dbcsr_col_i /= dbcsr_mat%col_i)) THEN
            CALL dbcsr_abort(__LOCATION__, &
                             "field dbcsr_col_i of DBCSR matrix not consistent with CSR matrix")
         END IF

         ! block sizes
         IF (ANY(mapping%dbcsr_row_blk_size /= dbcsr_row_block_sizes(dbcsr_mat))) THEN
            CALL dbcsr_abort(__LOCATION__, &
                             "field row_blk_size of DBCSR matrix not consistent with CSR matrix")
         END IF
         IF (ANY(mapping%dbcsr_col_blk_size /= dbcsr_col_block_sizes(dbcsr_mat))) THEN
            CALL dbcsr_abort(__LOCATION__, &
                             "field col_blk_size of DBCSR matrix not consistent with CSR matrix")
         END IF
      END IF

      CALL timestop(handle)
   END SUBROUTINE csr_assert_consistency_with_dbcsr

   SUBROUTINE convert_dbcsr_to_csr(dbcsr_mat, csr_mat)
      !! Convert a DBCSR matrix to a CSR matrix.
      !! The matrix is first redistributed into the block-row distributed (BRD)
      !! matrix held in csr_mat%dbcsr_mapping, which is then converted to CSR.
      !! Aborts if the DBCSR index is invalid, if the data types differ, or if
      !! csr_mat has no mapping.

      TYPE(dbcsr_type), INTENT(IN)                       :: dbcsr_mat
         !! DBCSR matrix to convert
      TYPE(csr_type), INTENT(INOUT)                      :: csr_mat
         !! correctly allocated CSR matrix

      CHARACTER(LEN=*), PARAMETER :: routineN = 'convert_dbcsr_to_csr'

      INTEGER                                            :: handle
      LOGICAL                                            :: symmetric
      TYPE(dbcsr_type)                                   :: work_mat

      CALL timeset(routineN, handle)

      ! Preconditions
      IF (.NOT. dbcsr_valid_index(dbcsr_mat)) THEN
         DBCSR_ABORT("Invalid DBCSR matrix")
      END IF
      IF (dbcsr_get_data_type(dbcsr_mat) .NE. csr_mat%nzval_local%data_type) THEN
         DBCSR_ABORT("DBCSR and CSR matrix must have same type")
      END IF
      IF (.NOT. csr_mat%has_mapping) THEN
         DBCSR_ABORT("CSR_mat must contain mapping to DBCSR format")
      END IF

      ! Work on a desymmetrized matrix; a non-symmetric input is used as is
      symmetric = dbcsr_has_symmetry(dbcsr_mat)
      IF (symmetric) THEN
         CALL dbcsr_desymmetrize_deep(dbcsr_mat, work_mat, untransposed_data=.TRUE.)
      ELSE
         work_mat = dbcsr_mat
      END IF

      CALL csr_assert_consistency_with_dbcsr(csr_mat, work_mat)

      ! Step 1: DBCSR -> BRD
      CALL dbcsr_complete_redistribute(work_mat, csr_mat%dbcsr_mapping%brd_mat)
      ! Step 2: BRD -> CSR
      CALL convert_brd_to_csr(csr_mat%dbcsr_mapping%brd_mat, csr_mat)

      ! Only the desymmetrized matrix is released here
      IF (symmetric) CALL dbcsr_release(work_mat)

      CALL timestop(handle)
   END SUBROUTINE convert_dbcsr_to_csr

   SUBROUTINE convert_csr_to_dbcsr(dbcsr_mat, csr_mat)
      !! convert a CSR matrix to a DBCSR matrix
      !! The CSR data is first written into the block-row distributed (BRD)
      !! matrix held in csr_mat%dbcsr_mapping, which is then redistributed into
      !! dbcsr_mat. Aborts if the DBCSR index is invalid, if the data types
      !! differ, or if csr_mat has no mapping.

      TYPE(dbcsr_type), INTENT(INOUT)                    :: dbcsr_mat
         !! correctly allocated DBCSR matrix
      TYPE(csr_type), INTENT(INOUT)                      :: csr_mat
         !! CSR matrix to convert

      CHARACTER(LEN=*), PARAMETER :: routineN = 'convert_csr_to_dbcsr'

      INTEGER                                            :: handle
      LOGICAL                                            :: symmetric
      TYPE(dbcsr_type)                                   :: work_mat

      CALL timeset(routineN, handle)

      ! Preconditions
      IF (.NOT. dbcsr_valid_index(dbcsr_mat)) THEN
         DBCSR_ABORT("Invalid DBCSR matrix")
      END IF
      IF (dbcsr_get_data_type(dbcsr_mat) .NE. csr_mat%nzval_local%data_type) THEN
         DBCSR_ABORT("DBCSR and CSR matrix must have same type")
      END IF
      IF (.NOT. csr_mat%has_mapping) THEN
         DBCSR_ABORT("CSR_mat must contain mapping to DBCSR format")
      END IF

      ! The consistency check is done on a desymmetrized matrix so that the
      ! sparsity pattern can be compared with the one stored in the CSR matrix
      symmetric = dbcsr_has_symmetry(dbcsr_mat)
      IF (symmetric) THEN
         CALL dbcsr_desymmetrize_deep(dbcsr_mat, work_mat, untransposed_data=.TRUE.)
      ELSE
         work_mat = dbcsr_mat
      END IF

      CALL csr_assert_consistency_with_dbcsr(csr_mat, work_mat)

      ! The desymmetrized matrix was only needed for the check
      IF (symmetric) CALL dbcsr_release(work_mat)

      ! Step 1: CSR -> BRD
      CALL convert_csr_to_brd(csr_mat%dbcsr_mapping%brd_mat, csr_mat)

      ! Step 2: BRD -> DBCSR
      CALL dbcsr_complete_redistribute(csr_mat%dbcsr_mapping%brd_mat, dbcsr_mat)

      CALL timestop(handle)
   END SUBROUTINE convert_csr_to_dbcsr

   SUBROUTINE dbcsr_to_csr_filter(dbcsr_mat, csr_sparsity, eps)
      !! Apply filtering threshold eps to DBCSR blocks in order to improve
      !! CSR sparsity (currently only used for testing purposes).
      !! csr_sparsity is created from dbcsr_mat with all elements set to 1; if
      !! eps is positive, every element whose absolute value in dbcsr_mat is
      !! below eps is set to 0 instead.

      TYPE(dbcsr_type), INTENT(IN)                       :: dbcsr_mat
         !! matrix whose element magnitudes are tested
      TYPE(dbcsr_type), INTENT(OUT)                      :: csr_sparsity
         !! resulting sparsity pattern (real_8 matrix of zeros and ones)
      REAL(kind=real_8), INTENT(IN)                      :: eps
         !! filtering threshold; no filtering is done unless eps > 0

      INTEGER                                            :: blkcol, blkrow, data_type, ncols_blk, &
                                                            nel_blk, nrows_blk
      LOGICAL                                            :: tr
      REAL(kind=real_8), ALLOCATABLE, DIMENSION(:)       :: abs_vals, mask_vals
      TYPE(dbcsr_data_obj)                               :: block
      TYPE(dbcsr_iterator)                               :: iter

      ! Start from a pattern in which every stored element counts as non-zero
      CALL dbcsr_create(csr_sparsity, &
                        template=dbcsr_mat, &
                        name="CSR sparsity", &
                        data_type=dbcsr_type_real_8)
      CALL dbcsr_copy(csr_sparsity, dbcsr_mat)
      CALL dbcsr_set(csr_sparsity, 1.0_dp)

      ! Filtering is only requested for a strictly positive threshold
      IF (.NOT. (eps .GT. 0.0_dp)) RETURN

      data_type = dbcsr_get_data_type(dbcsr_mat)
      CALL dbcsr_data_init(block)
      CALL dbcsr_data_new(block, data_type)

      CALL dbcsr_iterator_start(iter, dbcsr_mat, read_only=.TRUE.)
      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, blkrow, blkcol, block, transposed=tr, &
                                        row_size=nrows_blk, col_size=ncols_blk)

         nel_blk = nrows_blk*ncols_blk
         ALLOCATE (abs_vals(nel_blk))
         ALLOCATE (mask_vals(nel_blk))

         ! absolute values of the block elements, converted to real_8
         SELECT CASE (data_type)
         CASE (dbcsr_type_real_4)
            abs_vals(:) = REAL(ABS(block%d%r_sp(:)), KIND=real_8)
         CASE (dbcsr_type_real_8)
            abs_vals(:) = REAL(ABS(block%d%r_dp(:)), KIND=real_8)
         CASE (dbcsr_type_complex_4)
            abs_vals(:) = REAL(ABS(block%d%c_sp(:)), KIND=real_8)
         CASE (dbcsr_type_complex_8)
            abs_vals(:) = REAL(ABS(block%d%c_dp(:)), KIND=real_8)
         END SELECT

         ! 0 for elements below the threshold, 1 otherwise
         mask_vals(:) = MERGE(0.0_dp, 1.0_dp, abs_vals .LT. eps)
         CALL dbcsr_put_block(csr_sparsity, blkrow, blkcol, mask_vals, transposed=tr)

         DEALLOCATE (mask_vals, abs_vals)
      END DO
      CALL dbcsr_iterator_stop(iter)

      CALL dbcsr_data_clear_pointer(block)
      CALL dbcsr_data_release(block)

   END SUBROUTINE dbcsr_to_csr_filter

   SUBROUTINE csr_write(csr_mat, unit_nr, upper_triangle, threshold, binary)
      !! Write a CSR matrix to file.
      !! The local parts of all nodes are collected one after the other on node 0
      !! (master), which writes one record per element consisting of the global
      !! row index, the column index and the value. Only node 0 writes to unit_nr.
      !! Aborts if csr_mat is not valid.

      TYPE(csr_type), INTENT(IN)                         :: csr_mat
         !! CSR matrix to write
      INTEGER, INTENT(IN)                                :: unit_nr
         !! unit number to which output is written
      LOGICAL, INTENT(IN), OPTIONAL                      :: upper_triangle
         !! If true (default: false), write only upper triangular part of matrix
      REAL(KIND=real_8), INTENT(IN), OPTIONAL            :: threshold
         !! threshold on the absolute value of the elements to be printed
      LOGICAL, INTENT(IN), OPTIONAL                      :: binary
         !! If true (default: false), records are written with unformatted WRITE statements

      CHARACTER(LEN=*), PARAMETER :: routineN = 'csr_write'
      CHARACTER(LEN=default_string_length)               :: data_format
      COMPLEX(KIND=real_4), ALLOCATABLE, DIMENSION(:)    :: nzval_to_master_c_sp
      COMPLEX(KIND=real_8), ALLOCATABLE, DIMENSION(:)    :: nzval_to_master_c_dp
      INTEGER                                            :: handle, i, ii, k, l, m, mynode, &
                                                            numnodes, row_offset, rowind, tag1, &
                                                            tag2, tag3
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: colind_to_master, nzerow_to_master, &
                                                            sizes_numrowlocal, sizes_nzelocal
      LOGICAL                                            :: bin, ut
      REAL(KIND=real_4), ALLOCATABLE, DIMENSION(:)       :: nzval_to_master_r_sp
      REAL(KIND=real_8)                                  :: thld
      REAL(KIND=real_8), ALLOCATABLE, DIMENSION(:)       :: nzval_to_master_r_dp

      CALL timeset(routineN, handle)

      ! Resolve optional arguments to their defaults
      IF (PRESENT(upper_triangle)) THEN
         ut = upper_triangle
      ELSE
         ut = .FALSE.
      END IF

      IF (PRESENT(threshold)) THEN
         thld = threshold
      ELSE
         thld = 0.0_dp
      END IF

      IF (PRESENT(binary)) THEN
         bin = binary
      ELSE
         bin = .FALSE.
      END IF

      IF (.NOT. csr_mat%valid) &
         DBCSR_ABORT("can not write invalid CSR matrix")

      ! message tags for column indices, non-zeros per row and values
      tag1 = 0
      tag2 = 1
      tag3 = 2

      CALL mp_environ(numnodes, mynode, csr_mat%mp_group)

      ! gather sizes (number of local non-zero elements and number of local rows)
      ALLOCATE (sizes_nzelocal(numnodes))
      ALLOCATE (sizes_numrowlocal(numnodes))

      CALL mp_gather(csr_mat%nze_local, sizes_nzelocal, 0, csr_mat%mp_group)
      CALL mp_gather(csr_mat%nrows_local, sizes_numrowlocal, 0, csr_mat%mp_group)

      ! for each node, send matrix data to node 0 (master) and write data
      DO i = 0, numnodes - 1
         ! NOTE(review): ii is a plain copy of the loop index; presumably kept because
         ! the receive routines may modify their source argument - confirm
         ii = i
         IF (mynode .EQ. 0) THEN ! allocations for receiving data from node i
            ALLOCATE (colind_to_master(sizes_nzelocal(ii + 1)))
            ALLOCATE (nzerow_to_master(sizes_numrowlocal(ii + 1)))

            ! value buffer and output format depend on the data type
            SELECT CASE (csr_mat%nzval_local%data_type)
            CASE (dbcsr_type_real_4)
               data_format = "(2(I8),E23.6E2)"
               ALLOCATE (nzval_to_master_r_sp(sizes_nzelocal(ii + 1)))
            CASE (dbcsr_type_real_8)
               data_format = "(2(I8),E23.14E3)"
               ALLOCATE (nzval_to_master_r_dp(sizes_nzelocal(ii + 1)))
            CASE (dbcsr_type_complex_4)
               data_format = "(2(I8),2(E23.6E2))"
               ALLOCATE (nzval_to_master_c_sp(sizes_nzelocal(ii + 1)))
            CASE (dbcsr_type_complex_8)
               data_format = "(2(I8),2(E23.14E3))"
               ALLOCATE (nzval_to_master_c_dp(sizes_nzelocal(ii + 1)))
            END SELECT
         END IF

         IF (mynode .EQ. 0) THEN ! receive at node 0
            IF (ii .EQ. 0) THEN ! data from node 0, no need for mpi routines
               colind_to_master(:) = csr_mat%colind_local(:)
               nzerow_to_master(:) = csr_mat%nzerow_local(:)
               SELECT CASE (csr_mat%nzval_local%data_type)
               CASE (dbcsr_type_real_4)
                  nzval_to_master_r_sp(:) = csr_mat%nzval_local%r_sp(:)
               CASE (dbcsr_type_real_8)
                  nzval_to_master_r_dp(:) = csr_mat%nzval_local%r_dp(:)
               CASE (dbcsr_type_complex_4)
                  nzval_to_master_c_sp(:) = csr_mat%nzval_local%c_sp(:)
               CASE (dbcsr_type_complex_8)
                  nzval_to_master_c_dp(:) = csr_mat%nzval_local%c_dp(:)
               END SELECT
            ELSE ! receive data from nodes with rank > 0
               CALL mp_recv(colind_to_master, ii, tag1, csr_mat%mp_group)
               CALL mp_recv(nzerow_to_master, ii, tag2, csr_mat%mp_group)
               SELECT CASE (csr_mat%nzval_local%data_type)
               CASE (dbcsr_type_real_4)
                  CALL mp_recv(nzval_to_master_r_sp, ii, tag3, csr_mat%mp_group)
               CASE (dbcsr_type_real_8)
                  CALL mp_recv(nzval_to_master_r_dp, ii, tag3, csr_mat%mp_group)
               CASE (dbcsr_type_complex_4)
                  CALL mp_recv(nzval_to_master_c_sp, ii, tag3, csr_mat%mp_group)
               CASE (dbcsr_type_complex_8)
                  CALL mp_recv(nzval_to_master_c_dp, ii, tag3, csr_mat%mp_group)
               END SELECT
            END IF
         END IF

         IF ((mynode .EQ. ii) .AND. (ii .NE. 0)) THEN ! send from nodes with rank > 0
            CALL mp_send(csr_mat%colind_local, 0, tag1, csr_mat%mp_group)
            CALL mp_send(csr_mat%nzerow_local, 0, tag2, csr_mat%mp_group)
            SELECT CASE (csr_mat%nzval_local%data_type)
            CASE (dbcsr_type_real_4)
               CALL mp_send(csr_mat%nzval_local%r_sp, 0, tag3, csr_mat%mp_group)
            CASE (dbcsr_type_real_8)
               CALL mp_send(csr_mat%nzval_local%r_dp, 0, tag3, csr_mat%mp_group)
            CASE (dbcsr_type_complex_4)
               CALL mp_send(csr_mat%nzval_local%c_sp, 0, tag3, csr_mat%mp_group)
            CASE (dbcsr_type_complex_8)
               CALL mp_send(csr_mat%nzval_local%c_dp, 0, tag3, csr_mat%mp_group)
            END SELECT
         END IF

         IF (mynode .EQ. 0) THEN ! write data received at node 0
            !WRITE(unit_nr,"(A27)") "#row ind, col ind, value"
            ! number of rows held by all lower ranks: offset from local to global row
            ! index, computed once per node instead of once per row
            row_offset = SUM(sizes_numrowlocal(1:ii))
            m = 0 ! running index into the column index / value buffers
            DO k = 1, sizes_numrowlocal(ii + 1)
               rowind = k + row_offset ! row index: local to global
               DO l = 1, nzerow_to_master(k)
                  m = m + 1
                  ! with upper_triangle, skip elements below the diagonal
                  IF ((.NOT. ut) .OR. (rowind .LE. colind_to_master(m))) THEN
                     SELECT CASE (csr_mat%nzval_local%data_type)
                     CASE (dbcsr_type_real_4)
                        IF (ABS(nzval_to_master_r_sp(m)) .GE. thld) THEN
                           IF (bin) THEN
                              WRITE (unit_nr) rowind, colind_to_master(m), nzval_to_master_r_sp(m)
                           ELSE
                              WRITE (unit_nr, data_format) rowind, colind_to_master(m), &
                                 nzval_to_master_r_sp(m)
                           END IF
                        END IF
                     CASE (dbcsr_type_real_8)
                        IF (ABS(nzval_to_master_r_dp(m)) .GE. thld) THEN
                           IF (bin) THEN
                              WRITE (unit_nr) rowind, colind_to_master(m), nzval_to_master_r_dp(m)
                           ELSE
                              WRITE (unit_nr, data_format) rowind, colind_to_master(m), &
                                 nzval_to_master_r_dp(m)
                           END IF
                        END IF
                     CASE (dbcsr_type_complex_4)
                        IF (ABS(nzval_to_master_c_sp(m)) .GE. thld) THEN
                           IF (bin) THEN
                              WRITE (unit_nr) rowind, colind_to_master(m), nzval_to_master_c_sp(m)
                           ELSE
                              WRITE (unit_nr, data_format) rowind, colind_to_master(m), &
                                 nzval_to_master_c_sp(m)
                           END IF
                        END IF
                     CASE (dbcsr_type_complex_8)
                        IF (ABS(nzval_to_master_c_dp(m)) .GE. thld) THEN
                           IF (bin) THEN
                              WRITE (unit_nr) rowind, colind_to_master(m), nzval_to_master_c_dp(m)
                           ELSE
                              WRITE (unit_nr, data_format) rowind, colind_to_master(m), &
                                 nzval_to_master_c_dp(m)
                           END IF
                        END IF
                     END SELECT
                  END IF
               END DO
            END DO

            ! release the receive buffers before handling the next node
            DEALLOCATE (colind_to_master)
            DEALLOCATE (nzerow_to_master)

            SELECT CASE (csr_mat%nzval_local%data_type)
            CASE (dbcsr_type_real_4)
               DEALLOCATE (nzval_to_master_r_sp)
            CASE (dbcsr_type_real_8)
               DEALLOCATE (nzval_to_master_r_dp)
            CASE (dbcsr_type_complex_4)
               DEALLOCATE (nzval_to_master_c_sp)
            CASE (dbcsr_type_complex_8)
               DEALLOCATE (nzval_to_master_c_dp)
            END SELECT
         END IF
      END DO

      CALL timestop(handle)

   END SUBROUTINE csr_write

   SUBROUTINE csr_print_sparsity(csr_mat, unit_nr)
      !! Print CSR sparsity: the total number of CSR non-zero elements and their
      !! percentage of nrows_total*ncols_total. Only node 0 writes to unit_nr.
      !! Aborts if csr_mat is not valid.
      TYPE(csr_type), INTENT(IN)                         :: csr_mat
         !! CSR matrix to report on
      INTEGER, INTENT(IN)                                :: unit_nr
         !! unit number to which output is written

      CHARACTER(LEN=*), PARAMETER :: routineN = 'csr_print_sparsity'

      INTEGER                                            :: handle, mynode, numnodes
      INTEGER(KIND=int_8)                                :: dbcsr_nze_total
      REAL(KIND=real_8)                                  :: dbcsr_nze_percentage, ncols_r, &
                                                            nrows_r, nze_percentage

      CALL timeset(routineN, handle)

      IF (.NOT. csr_mat%valid) THEN
         DBCSR_ABORT("CSR matrix must be created first")
      END IF

      ! matrix dimensions as reals, shared by both percentages
      nrows_r = REAL(csr_mat%nrows_total, KIND=real_8)
      ncols_r = REAL(csr_mat%ncols_total, KIND=real_8)

      nze_percentage = 100.0_dp*(REAL(csr_mat%nze_total, KIND=real_8)/nrows_r)/ncols_r

      IF (csr_mat%has_mapping) THEN
         ! element count of the BRD matrix, summed over the process group
         ! (the value is currently not printed, see the disabled output below)
         dbcsr_nze_total = dbcsr_get_nze(csr_mat%dbcsr_mapping%brd_mat)
         CALL mp_sum(dbcsr_nze_total, csr_mat%mp_group)
         dbcsr_nze_percentage = 100.0_dp*(REAL(dbcsr_nze_total, KIND=real_8)/nrows_r)/ncols_r
      END IF

      CALL mp_environ(numnodes, mynode, csr_mat%mp_group)

      IF (mynode .EQ. 0) THEN
         WRITE (unit_nr, "(T15,A,T68,I13)") "Number of  CSR non-zero elements:", csr_mat%nze_total
         WRITE (unit_nr, "(T15,A,T75,F6.2)") "Percentage CSR non-zero elements:", nze_percentage
         ! disabled output of the DBCSR percentage:
         !IF(csr_mat%has_mapping) THEN
         !  WRITE(unit_nr,"(T15,A,T75,F6.2/)") "Percentage DBCSR non-zero elements:", dbcsr_nze_percentage
         !ENDIF
      END IF

      CALL timestop(handle)
   END SUBROUTINE csr_print_sparsity

   SUBROUTINE dbcsr_create_brd(dbcsr_mat, brd_mat, equal_dist, floor_dist, numnodes)
      !! Converts a DBCSR matrix to a block row distributed matrix.

      TYPE(dbcsr_type), INTENT(IN)                       :: dbcsr_mat
         !! DBCSR matrix to be converted
      TYPE(dbcsr_type), INTENT(OUT)                      :: brd_mat
         !! converted matrix
      LOGICAL, INTENT(IN)                                :: equal_dist, floor_dist
         !! see documentation of csr_create_from_dbcsr
         !! see documentation of csr_create_from_dbcsr
      INTEGER, INTENT(IN)                                :: numnodes
         !! number of nodes to use for block row distribution

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_create_brd'

      CHARACTER                                          :: matrix_type
      CHARACTER(LEN=default_string_length)               :: matrix_name
      INTEGER :: cs, data_type, end_ind, handle, i, k, l, m, mynode, nblkcols_total, &
                 nblkrows_total, nfullrows_local, nfullrows_total, node_size, numnodes_total, row_index, &
                 split_row, start_ind
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: rdist_tmp, row_blk_size_new_tmp
      INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: pgrid
      INTEGER, DIMENSION(:), POINTER, CONTIGUOUS         :: cdist, col_blk_size, rdist, &
                                                            row_blk_size, row_blk_size_new
      REAL(KIND=real_8)                                  :: chunk_size
      TYPE(dbcsr_distribution_obj)                       :: dist_current, dist_new
      TYPE(dbcsr_mp_obj)                                 :: mp_obj_current, mp_obj_new
      TYPE(mp_comm_type)                                 :: mp_group

      CALL timeset(routineN, handle)

      NULLIFY (row_blk_size, rdist, row_blk_size_new)
      CALL dbcsr_get_info(dbcsr_mat, &
                          nblkrows_total=nblkrows_total, &
                          nblkcols_total=nblkcols_total, &
                          nfullrows_total=nfullrows_total, &
                          row_blk_size=row_blk_size, &
                          col_blk_size=col_blk_size, &
                          matrix_type=matrix_type, &
                          data_type=data_type)

      matrix_name = dbcsr_name(dbcsr_mat)

      ALLOCATE (cdist(nblkcols_total))
      cdist = 0

      dist_current = dbcsr_distribution(dbcsr_mat)
      mp_obj_current = dbcsr_distribution_mp(dist_current)
      mp_group = dbcsr_mp_group(mp_obj_current)
      mynode = dbcsr_mp_mynode(mp_obj_current)
      numnodes_total = dbcsr_mp_numnodes(mp_obj_current)

      ALLOCATE (pgrid(numnodes_total, 1))

      IF (equal_dist) THEN ! Equally distribute rows over processors -> cut blocks

         ! Calculate the number of rows a processor can hold
         IF (floor_dist) THEN
            nfullrows_local = FLOOR(REAL(nfullrows_total, KIND=dp)/numnodes)
         ELSE
            nfullrows_local = CEILING(REAL(nfullrows_total, KIND=dp)/numnodes)
         END IF

         ! allocate maximum amount of memory possibly needed
         ALLOCATE (rdist_tmp(nblkrows_total + numnodes - 1)) ! row distribution
         ALLOCATE (row_blk_size_new_tmp(nblkrows_total + numnodes - 1)) ! new sizes of block rows

         k = 0 ! counter for block rows
         m = 0 ! node counter
         node_size = nfullrows_local ! space available on current node in number of rows
         IF (node_size .GT. 0) THEN
            DO l = 1, nblkrows_total
               split_row = row_blk_size(l) ! size of current block row (number of rows)
               DO WHILE (split_row .GE. node_size) ! cut block row and send it to two nodes
                  k = k + 1
                  m = m + 1
                  row_blk_size_new_tmp(k) = node_size ! size of first part of block row
                  rdist_tmp(k) = m - 1 ! send first part to node m
                  split_row = split_row - node_size ! size of remaining part of block rows
                  node_size = nfullrows_local ! space available on next node
                  IF (floor_dist .AND. (m .EQ. numnodes - 1)) THEN ! send all remaining rows to last node
                     node_size = nfullrows_total - (numnodes - 1)*node_size
                  END IF
               END DO
               IF (split_row .GT. 0) THEN ! enough space left on next node for remaining rows
                  k = k + 1
                  row_blk_size_new_tmp(k) = split_row ! size of remaining part of block row
                  rdist_tmp(k) = m ! send to next node
                  node_size = node_size - split_row ! remaining space on next node
               END IF
            END DO
         ELSE ! send everything to last node if node_size = 0
            rdist_tmp(1:nblkrows_total) = numnodes - 1
            row_blk_size_new_tmp(1:nblkrows_total) = row_blk_size ! row blocks unchanged
            k = nblkrows_total
         END IF

         ! Copy data to correctly allocated variables
         ALLOCATE (row_blk_size_new(k))
         row_blk_size_new = row_blk_size_new_tmp(1:k)
         ALLOCATE (rdist(k))
         rdist = rdist_tmp(1:k)

      ELSE ! Leave block rows intact (do not cut)
         ALLOCATE (rdist(nblkrows_total))
         rdist = 0
         IF (numnodes .GT. nblkrows_total) THEN
            rdist = (/(i, i=0, nblkrows_total - 1)/)
         ELSE
            chunk_size = REAL(nblkrows_total, KIND=dp)/numnodes
            row_index = 0
            start_ind = 1
            DO i = 0, numnodes - 1
               cs = NINT(i*chunk_size) - NINT((i - 1)*chunk_size)
               end_ind = MIN(start_ind - 1 + cs, nblkrows_total)
               rdist(start_ind:end_ind) = row_index
               start_ind = end_ind + 1
               row_index = row_index + 1
            END DO
         END IF
         row_blk_size_new => row_blk_size
      END IF

      pgrid(:, :) = RESHAPE((/(i, i=0, numnodes_total - 1)/), (/numnodes_total, 1/))
      CALL dbcsr_mp_new(mp_obj_new, mp_group, pgrid, mynode, numnodes=numnodes_total)
      CALL dbcsr_distribution_new(dist_new, mp_obj_new, rdist, cdist, reuse_arrays=.TRUE.)

      CALL dbcsr_create(brd_mat, TRIM(matrix_name)//" row-block distributed", &
                        dist_new, matrix_type, row_blk_size_new, col_blk_size, data_type=data_type)
      CALL dbcsr_complete_redistribute(dbcsr_mat, brd_mat)

      DEALLOCATE (pgrid)

      IF (equal_dist) DEALLOCATE (row_blk_size_new)

      CALL dbcsr_distribution_release(dist_new)
      CALL dbcsr_mp_release(mp_obj_new)

      CALL timestop(handle)

   END SUBROUTINE dbcsr_create_brd

END MODULE dbcsr_csr_conversions
